import json
import numpy as np
from transformers import AutoTokenizer
from nltk.util import ngrams
from matplotlib import pyplot as plt

def ngram_diversity(tokens, n):
    """Return the distinct-n ratio of a single token sequence.

    Args:
        tokens: Tokens to extract n-grams from (passed to ``nltk.util.ngrams``).
        n: n-gram order.

    Returns:
        The number of distinct n-grams divided by the total number of n-grams
        (a value in (0, 1]), or ``None`` when no n-grams are produced, for
        instance when the sequence is shorter than ``n``.
    """
    grams = [*ngrams(tokens, n)]
    if not grams:
        # No n-grams at all: the ratio is undefined, signal it with None.
        return None
    distinct = set(grams)
    return len(distinct) / len(grams)

model_name = 'meta-llama/Llama-3.1-8B-Instruct'


def corpus_ngram_diversity(token_lists, n):
    """Return corpus-level distinct-n over several token sequences.

    n-grams are extracted *within* each sequence only, then pooled, so no
    n-gram spans the boundary between two responses. The result is
    (distinct pooled n-grams) / (total pooled n-grams), or ``None`` when no
    sequence yields an n-gram (same convention as ``ngram_diversity``).
    """
    pooled = [gram for tokens in token_lists for gram in ngrams(tokens, n)]
    if not pooled:
        return None
    return len(set(pooled)) / len(pooled)


def plot_per_response_diversity(normal_token_lists, perturbed_token_lists, n=5,
                                path="ngram_diversity.jpg"):
    """Plot overlapping histograms of per-response n-gram diversity.

    Not called by default (this analysis was previously disabled inside a
    string literal); call it explicitly to regenerate the figure. Responses
    that yield no n-grams are skipped. Prints the mean diversity of the
    normal group, then of the perturbed group (``None`` for an empty group),
    and saves the figure to *path*.

    Raises:
        ValueError: if no response in either group produced an n-gram.
    """
    normal_diversities = [
        d for d in (ngram_diversity(tokens, n) for tokens in normal_token_lists)
        if d is not None
    ]
    perturbed_diversities = [
        d for d in (ngram_diversity(tokens, n) for tokens in perturbed_token_lists)
        if d is not None
    ]
    all_values = normal_diversities + perturbed_diversities
    if not all_values:
        raise ValueError("no response produced any n-grams; nothing to plot")

    # Shared bin edges so the two overlapping histograms are comparable.
    bins = np.linspace(min(all_values), max(all_values), 10)
    fig, ax = plt.subplots()
    ax.hist(normal_diversities, bins=bins, alpha=0.5, label='Normal', edgecolor='black')
    ax.hist(perturbed_diversities, bins=bins, alpha=0.5, label='Perturbed', edgecolor='black')
    ax.set_xlabel('Diversity')
    ax.set_ylabel('Frequency')
    ax.set_title('Overlapping Histograms of Diversity Values')
    ax.legend()

    # Average over the responses that actually had n-grams; the disabled
    # original divided by the count of all responses, biasing the mean low.
    for values in (normal_diversities, perturbed_diversities):
        print(sum(values) / len(values) if values else None)
    fig.savefig(path)
    plt.close(fig)


# token=None lets huggingface_hub fall back to the HF_TOKEN environment
# variable or a cached `huggingface-cli login` token (meta-llama repos are
# typically gated). The previous empty string was passed as an explicit token
# rather than meaning "use the default". trust_remote_code is dropped: the
# Llama tokenizer is implemented in transformers itself, so no repository
# code needs to be executed.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=None)

# JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
with open("alpaca_reference.json", "r", encoding="utf-8") as f:
    normaljson = json.load(f)

with open("alpaca_target.json", "r", encoding="utf-8") as f:
    perturbedjson = json.load(f)

# One token list per response. Extending a single flat list instead would
# create spurious n-grams spanning the end of one response and the start of
# the next, distorting the diversity score.
normal_responses = [tokenizer.tokenize(row["output"]) for row in normaljson]
perturbed_responses = [tokenizer.tokenize(row["output"]) for row in perturbedjson]

print(corpus_ngram_diversity(normal_responses, 3))
print(corpus_ngram_diversity(perturbed_responses, 3))